package com.tyron.completion.java.util;

import com.github.javaparser.Position;
import com.github.javaparser.Range;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.ImportDeclaration;
import com.github.javaparser.ast.Modifier;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.PackageDeclaration;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.AnnotationExpr;
import com.github.javaparser.ast.nodeTypes.NodeWithAnnotations;
import com.github.javaparser.ast.nodeTypes.NodeWithModifiers;
import com.github.javaparser.ast.nodeTypes.NodeWithRange;
import com.github.javaparser.ast.nodeTypes.NodeWithTypeParameters;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.PrimitiveType;
import com.github.javaparser.ast.type.ReferenceType;
import com.github.javaparser.ast.type.Type;
import com.github.javaparser.ast.type.TypeParameter;
import com.github.javaparser.ast.type.UnknownType;
import com.github.javaparser.ast.type.VoidType;
import com.github.javaparser.ast.type.WildcardType;
import com.tyron.completion.java.compiler.ParseTask;

import com.sun.source.tree.AnnotationTree;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.IdentifierTree;
import com.sun.source.tree.ImportTree;
import com.sun.source.tree.LineMap;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.ModifiersTree;
import com.sun.source.tree.ParameterizedTypeTree;
import com.sun.source.tree.PrimitiveTypeTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.TypeParameterTree;
import com.sun.source.tree.WildcardTree;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.code.BoundKind;
import com.sun.tools.javac.tree.DocCommentTable;
import com.sun.tools.javac.tree.EndPosTable;
import com.sun.tools.javac.tree.JCTree;

import java.util.List;

/**
 * Converts the {@link CompilationUnitTree} produced by javac into a JavaParser
 * {@link CompilationUnit}.
 *
 * <p>Only the declaration skeleton is converted: the package declaration, imports, top-level
 * and member classes/interfaces, and method signatures (modifiers, annotations, type
 * parameters, return type and name). Fields, initializer blocks, method parameters, throws
 * clauses and method bodies are not converted. Ranges are computed from javac's character
 * offsets through a {@link LineColumnCallback}.
 */
public class CompilationUnitConverter extends TreeScanner<Void, Node> {

    /**
     * Delegate for converting an index based position into a line and column based position.
     *
     * <p>This is needed because javac treats tabs as 8 spaces, which is inconsistent with
     * editors that treat tabs as 4 spaces.
     */
    public interface LineColumnCallback {
        int getLine(int pos);
        int getColumn(int pos);
    }

    private CompilationUnit mCompilationUnit = null;
    private LineMap mLineMap;
    private EndPosTable mEndPosTable = null;
    private DocCommentTable mDocComments = null;

    private final ParseTask mParseTask;
    private final String mContents;
    private final LineColumnCallback mCallback;

    /**
     * @param task     parse result whose {@code root} is converted
     * @param contents source text; javac offsets are used to index into it, so it is expected
     *                 to be the text {@code task} was parsed from
     * @param callback converts character offsets into line/column positions
     */
    public CompilationUnitConverter(ParseTask task, String contents, LineColumnCallback callback) {
        mParseTask = task;
        mContents = contents;
        mCallback = callback;
    }

    /**
     * Runs the conversion on the parse task's root.
     *
     * @return the converted compilation unit, or {@code null} if no compilation unit was visited
     */
    public CompilationUnit startScan() {
        scan(mParseTask.root, null);
        return mCompilationUnit;
    }

    /**
     * Creates {@link #mCompilationUnit} with its package declaration and imports, then scans
     * every top-level type declaration with the compilation unit as parent.
     */
    @Override
    public Void visitCompilationUnit(CompilationUnitTree node, Node current) {
        JCTree.JCCompilationUnit unit = (JCTree.JCCompilationUnit) node;
        mLineMap = unit.getLineMap();
        mEndPosTable = unit.endPositions;
        mDocComments = unit.docComments;
        mCompilationUnit = new CompilationUnit();
        addNodeRange(unit, mCompilationUnit);

        if (unit.getPackageName() != null) {
            PackageDeclaration packageDeclaration = new PackageDeclaration();
            packageDeclaration.setName(unit.getPackageName().toString());
            addNodeRange(node.getPackage(), packageDeclaration);
            mCompilationUnit.setPackageDeclaration(packageDeclaration);
        }

        for (ImportTree importTree : unit.getImports()) {
            String name = importTree.getQualifiedIdentifier().toString();
            boolean isAsterisk = name.endsWith("*");
            if (isAsterisk) {
                // JavaParser models "on-demand" imports with the isAsterisk flag, so the name
                // itself must not carry the trailing ".*"
                name = name.substring(0, name.length() - 1);
                if (name.endsWith(".")) {
                    name = name.substring(0, name.length() - 1);
                }
            }
            if (name.isEmpty()) {
                continue;
            }
            ImportDeclaration importDeclaration =
                    new ImportDeclaration(name, importTree.isStatic(), isAsterisk);
            mCompilationUnit.addImport(importDeclaration);
            addNodeRange(importTree, importDeclaration);
        }

        for (Tree decl : node.getTypeDecls()) {
            this.scan(decl, mCompilationUnit);
        }
        mEndPosTable = null;
        return null;
    }

    /**
     * Converts a class or interface declaration and attaches it to {@code current}: as a type
     * of a {@link CompilationUnit}, or as a member of an enclosing
     * {@link ClassOrInterfaceDeclaration}.
     */
    @Override
    public Void visitClass(ClassTree classTree, Node current) {
        ClassOrInterfaceDeclaration declaration = new ClassOrInterfaceDeclaration();
        addNodeRange(classTree, declaration);

        boolean isInterface = classTree.getKind() == Tree.Kind.INTERFACE;
        declaration.setInterface(isInterface);

        if (classTree.getModifiers() != null) {
            scan(classTree.getModifiers(), declaration);
        }

        Tree extendsClause = classTree.getExtendsClause();
        if (extendsClause != null) {
            ClassOrInterfaceType extended =
                    JavaParserTypesUtil.toClassOrInterfaceType(extendsClause);
            addNodeRange(extendsClause, extended);
            declaration.setExtendedTypes(NodeList.nodeList(extended));
        }

        List<? extends Tree> implementsClause = classTree.getImplementsClause();
        for (Tree tree : implementsClause) {
            ClassOrInterfaceType superType = JavaParserTypesUtil.toClassOrInterfaceType(tree);
            addNodeRange(tree, superType);
            if (isInterface) {
                // javac stores the "extends" list of an interface in its implements clause
                declaration.addExtendedType(superType);
            } else {
                declaration.addImplementedType(superType);
            }
        }

        JCTree jcClass = (JCTree) classTree;
        int classStart = jcClass.getStartPosition();
        String contents = sourceSlice(classStart, jcClass.getEndPosition(mEndPosTable));

        String name = classTree.getSimpleName().toString();
        declaration.setName(name);
        declaration.getName().setRange(getRangeForName(contents, name, classStart));

        if (classTree.getTypeParameters() != null) {
            scan(classTree.getTypeParameters(), declaration);
        }

        for (Tree member : classTree.getMembers()) {
            // Only nested types and methods are converted. Other members (fields, initializer
            // blocks) are skipped: TreeScanner's default traversal would hand their modifiers
            // to visitModifiers with this class as the target and overwrite its modifiers.
            if (member instanceof ClassTree || member instanceof MethodTree) {
                scan(member, declaration);
            }
        }

        if (current instanceof CompilationUnit) {
            ((CompilationUnit) current).addType(declaration);
        } else if (current instanceof ClassOrInterfaceDeclaration) {
            ((ClassOrInterfaceDeclaration) current).addMember(declaration);
        }
        return null;
    }

    /**
     * Converts a method (or constructor) signature and adds it to {@code node} when that is a
     * {@link ClassOrInterfaceDeclaration}. The method body is not visited.
     */
    @Override
    public Void visitMethod(MethodTree methodTree, Node node) {
        JCTree jcMethod = (JCTree) methodTree;
        int methodStart = jcMethod.getStartPosition();
        String contents = sourceSlice(methodStart, jcMethod.getEndPosition(mEndPosTable));

        MethodDeclaration methodDeclaration = new MethodDeclaration();

        if (methodTree.getModifiers() != null) {
            scan(methodTree.getModifiers(), methodDeclaration);
        }

        if (methodTree.getTypeParameters() != null) {
            scan(methodTree.getTypeParameters(), methodDeclaration);
        }

        // getReturnType() is null for constructors; getType maps that to an UnknownType
        Type type = getType(methodTree.getReturnType());
        methodDeclaration.setType(type);

        String name = methodTree.getName().toString();
        methodDeclaration.setName(name);
        // javac names constructors "<init>", but in source they are spelled as the class name
        String sourceName = name;
        if (methodTree.getReturnType() == null && node instanceof ClassOrInterfaceDeclaration) {
            sourceName = ((ClassOrInterfaceDeclaration) node).getNameAsString();
        }
        methodDeclaration.getName().setRange(getRangeForName(contents, sourceName, methodStart));

        addNodeRange(methodTree, methodDeclaration);
        if (node instanceof ClassOrInterfaceDeclaration) {
            ((ClassOrInterfaceDeclaration) node).addMember(methodDeclaration);
        }
        return null;
    }

    /**
     * Converts a type parameter (name and bounds) and adds it to {@code node} when that node
     * can hold type parameters; otherwise does nothing.
     */
    @Override
    public Void visitTypeParameter(TypeParameterTree typeParameterTree, Node node) {
        if (!(node instanceof NodeWithTypeParameters)) {
            return null;
        }
        JCTree jcTypeParameter = (JCTree) typeParameterTree;
        int typeStart = jcTypeParameter.getStartPosition();
        String contents = sourceSlice(typeStart, jcTypeParameter.getEndPosition(mEndPosTable));

        TypeParameter typeParameter = new TypeParameter();
        addNodeRange(typeParameterTree, typeParameter);

        String name = typeParameterTree.getName().toString();
        typeParameter.setName(name);

        if (typeParameterTree.getBounds() != null) {
            NodeList<ClassOrInterfaceType> bounds = new NodeList<>();
            for (Tree bound : typeParameterTree.getBounds()) {
                bounds.add(getClassOrInterfaceType(bound));
            }
            typeParameter.setTypeBound(bounds);
        }

        typeParameter.getName().setRange(getRangeForName(contents, name, typeStart));
        ((NodeWithTypeParameters<?>) node).addTypeParameter(typeParameter);
        return null;
    }

    /**
     * Converts a simple ({@code Foo}) or parameterized ({@code Foo<Bar>}) type reference.
     * Other tree kinds yield a ClassOrInterfaceType with JavaParser's default name.
     */
    private ClassOrInterfaceType getClassOrInterfaceType(Tree tree) {
        ClassOrInterfaceType type = new ClassOrInterfaceType();
        if (tree instanceof IdentifierTree) {
            type.setName(((IdentifierTree) tree).getName().toString());
            type.getName().setRange(getTreeRange(tree));
        } else if (tree instanceof ParameterizedTypeTree) {
            ParameterizedTypeTree parameterized = (ParameterizedTypeTree) tree;
            Type rawType = getType(parameterized.getType());
            if (rawType.isClassOrInterfaceType()) {
                type.setName(rawType.asClassOrInterfaceType().getName());
            }

            NodeList<Type> typeArguments = new NodeList<>();
            for (Tree typeArgument : parameterized.getTypeArguments()) {
                typeArguments.add(getType(typeArgument));
            }
            type.setTypeArguments(typeArguments);
        }
        if (tree != null) {
            addNodeRange(tree, type);
        }
        return type;
    }

    /**
     * Converts a type tree into a JavaParser {@link Type}. Unsupported tree kinds become an
     * {@link UnknownType}; a {@code null} tree (e.g. a constructor's return type) becomes an
     * {@link UnknownType} without a range.
     */
    private Type getType(Tree tree) {
        if (tree == null) {
            return new UnknownType();
        }
        Type type;
        if (tree instanceof PrimitiveTypeTree) {
            type = getPrimitiveType((PrimitiveTypeTree) tree);
        } else if (tree instanceof IdentifierTree || tree instanceof ParameterizedTypeTree) {
            type = getClassOrInterfaceType(tree);
        } else if (tree instanceof WildcardTree) {
            type = getWildcardType((WildcardTree) tree);
        } else {
            type = new UnknownType();
        }
        type.setRange(getTreeRange(tree));
        return type;
    }

    /**
     * Converts {@code ?}, {@code ? extends T} and {@code ? super T}. A bound that does not
     * convert to a {@link ReferenceType} is dropped instead of failing the whole conversion.
     */
    private WildcardType getWildcardType(WildcardTree tree) {
        WildcardType wildcardType = new WildcardType();
        Tree bound = tree.getBound();
        if (bound == null) {
            // unbounded wildcard "?"
            return wildcardType;
        }
        Type boundType = getType(bound);
        if (boundType instanceof ReferenceType) {
            if (tree.getKind() == Tree.Kind.SUPER_WILDCARD) {
                wildcardType.setSuperType((ReferenceType) boundType);
            } else {
                wildcardType.setExtendedType((ReferenceType) boundType);
            }
        }
        return wildcardType;
    }

    /**
     * Maps a javac primitive type (including {@code void}) to its JavaParser counterpart.
     */
    private Type getPrimitiveType(PrimitiveTypeTree tree) {
        Type type;
        switch (tree.getPrimitiveTypeKind()) {
            case INT:
                type = PrimitiveType.intType();
                break;
            case BOOLEAN:
                type = PrimitiveType.booleanType();
                break;
            case LONG:
                type = PrimitiveType.longType();
                break;
            case SHORT:
                type = PrimitiveType.shortType();
                break;
            case BYTE:
                type = PrimitiveType.byteType();
                break;
            case CHAR:
                type = PrimitiveType.charType();
                break;
            case FLOAT:
                type = PrimitiveType.floatType();
                break;
            case DOUBLE:
                type = PrimitiveType.doubleType();
                break;
            case VOID:
                type = new VoidType();
                break;
            default:
                type = new UnknownType();
        }
        type.setRange(getTreeRange(tree));
        return type;
    }

    /**
     * Copies modifier keywords and annotations onto {@code node} when it supports them.
     * Every modifier is given the range of the whole modifier list.
     */
    @Override
    public Void visitModifiers(ModifiersTree modifiersTree, Node node) {
        // ModifiersTree also contains annotations, add them here
        if (node instanceof NodeWithModifiers) {
            NodeList<Modifier> modifiers = new NodeList<>();
            for (javax.lang.model.element.Modifier flag : modifiersTree.getFlags()) {
                Modifier modifier = JavaParserUtil.toModifier(flag);
                addNodeRange(modifiersTree, modifier);
                modifiers.add(modifier);
            }
            ((NodeWithModifiers<?>) node).setModifiers(modifiers);
        }
        if (node instanceof NodeWithAnnotations) {
            for (AnnotationTree annotation : modifiersTree.getAnnotations()) {
                AnnotationExpr expr = JavaParserUtil.toAnnotation(annotation);
                addNodeRange(annotation, expr);
                ((NodeWithAnnotations<?>) node).addAnnotation(expr);
            }
        }
        return null;
    }

    /** Converts a character offset into a line/column position via the callback. */
    private Position getPosition(int start) {
        int line = mCallback.getLine(start);
        int column = mCallback.getColumn(start);
        return new Position(line, column);
    }

    /**
     * Creates a range object based on the position of the tree. When javac has no end position
     * for the tree, the range covers a single character.
     *
     * @param tree Javac Tree
     * @return Range representing the position of the tree
     */
    private Range getTreeRange(Tree tree) {
        JCTree jcTree = ((JCTree) tree);
        int start = jcTree.getStartPosition();
        int end = jcTree.getEndPosition(mEndPosTable);
        if (end < 0) {
            end = start + 1;
        }
        return Range.range(getPosition(start), getPosition(end));
    }

    /**
     * Returns the source text from {@code start} to {@code end}. A missing end position (e.g.
     * javac's NOPOS, -1) extends the slice to the end of the contents instead of throwing.
     */
    private String sourceSlice(int start, int end) {
        int length = mContents.length();
        if (end < start || end > length) {
            end = length;
        }
        return mContents.substring(start, end);
    }

    /**
     * Convenience method for getting the range of a name identifier.
     *
     * @param contents The contents of the tree where the name can be found
     * @param name Name to get the range for
     * @param startPos The start index of the tree from the main CompilationUnit
     * @return the range of the name, or an empty range at {@code startPos} when the name does
     *         not appear in {@code contents}
     */
    private Range getRangeForName(String contents, String name, int startPos) {
        int startName = indexOfIdentifier(contents, name);
        if (startName < 0) {
            Position position = getPosition(startPos);
            return Range.range(position, position);
        }
        int start = startPos + startName;
        int end = start + name.length();
        return Range.range(getPosition(start), getPosition(end));
    }

    /**
     * Finds the first occurrence of {@code name} in {@code text} that is a whole identifier,
     * so that e.g. a method named {@code in} is not matched inside the keyword {@code int}.
     * Occurrences directly after {@code @} (annotation names) are skipped as well.
     *
     * @return the index of the match, or -1 if there is none
     */
    private static int indexOfIdentifier(String text, String name) {
        if (name.isEmpty()) {
            return -1;
        }
        int from = 0;
        while (true) {
            int index = text.indexOf(name, from);
            if (index < 0) {
                return -1;
            }
            int after = index + name.length();
            boolean startsWord = index == 0
                    || (!Character.isJavaIdentifierPart(text.charAt(index - 1))
                    && text.charAt(index - 1) != '@');
            boolean endsWord = after >= text.length()
                    || !Character.isJavaIdentifierPart(text.charAt(after));
            if (startsWord && endsWord) {
                return index;
            }
            from = index + 1;
        }
    }

    /**
     * Convenience method to add a range of the tree into its node
     * @param tree javac generated tree
     * @param node node on where to add the range
     */
    private void addNodeRange(Tree tree, NodeWithRange<?> node) {
        node.setRange(getTreeRange(tree));
    }
}
